import time
import threading
import os
import json
import ssl
import uuid
from typing import Callable, Optional
import paho.mqtt.client as mqtt
import logging
import socket

# Directory holding this machine's TLS credentials. The commented-out install
# location was /etc/vyomcloudbridge/mosquito/certs/; the default below is a
# developer checkout path. Set VYOM_MQTT_CERT_DIR to use another directory
# without editing this file.
cert_dir = os.environ.get(
    "VYOM_MQTT_CERT_DIR",
    "/Users/amardeepsaini/Documents/VYOM/vyom-cloud-bridge/vyomcloudbridge/services/mqtt_mosquito/32/",
)

# Client certificate, e.g. /etc/vyomcloudbridge/mosquito/certs/machine.cert.pem
cert_file_name = "machine.cert.pem"
cert_file_path = os.path.join(cert_dir, cert_file_name)

# Client private key, e.g. /etc/vyomcloudbridge/mosquito/certs/machine.private.key
pri_key_file_name = "machine.private.key"
pri_key_file_path = os.path.join(cert_dir, pri_key_file_name)

# Root CA used to verify the broker, e.g. /etc/vyomcloudbridge/mosquito/certs/root-CA.crt
root_ca_file_name = "root-CA.crt"
root_ca_file_path = os.path.join(cert_dir, root_ca_file_name)

# Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# stream_handler = logging.StreamHandler()
# stream_handler.setFormatter(formatter)
# logger.addHandler(stream_handler)


class MqttClient:
    """Device-side MQTT client that talks to the broker over mutual TLS.

    The constructor verifies the certificate files, configures a paho
    ``mqtt.Client`` and connects immediately (raising if it cannot).
    Topics subscribed through ``subscribe_to_topic`` are remembered and
    re-subscribed whenever ``_on_connect`` reports a successful connection.
    Every received message is decoded as UTF-8 and passed to the optional
    ``callback(topic, payload)``.

    Threading: ``loop_start()`` runs paho's network loop on a background
    thread, and the ``_on_*`` callbacks execute on that thread.
    ``connection_lock`` protects ``is_connected`` and
    ``_connection_in_progress``. It is only held for short critical sections
    and never while waiting for the broker, because ``_on_connect`` must
    acquire it to record a successful connection.
    """

    def __init__(
        self,
        broker_host: str = "3.110.9.183",
        broker_port: int = 8884,  # SSL/TLS port
        callback: Optional[Callable] = None,
        daemon: bool = False,
    ):
        """Configure the client and open the initial broker connection.

        Args:
            broker_host: Broker hostname or IP address.
            broker_port: Broker TLS port.
            callback: Optional ``callback(topic, payload)`` invoked for every
                received message, with the payload decoded as UTF-8 text.
            daemon: Stored on the instance; not otherwise read by this class.

        Raises:
            FileNotFoundError: If any certificate file is missing.
            ConnectionError: If the initial connection cannot be established.
        """
        try:
            self.machine_id = "32"
            self.broker_host = broker_host
            self.broker_port = broker_port
            self.callback = callback
            self.daemon = daemon
            self.client_id = f"machine_{self.machine_id}"

            # Certificate paths (module-level configuration)
            self.cert_path = cert_file_path
            self.pri_key_path = pri_key_file_path
            self.root_ca_path = root_ca_file_path

            self._verify_cert_files()

            # Connection state tracking; guarded by connection_lock.
            self.is_connected = False
            self.connection_lock = threading.Lock()
            # Serializes whole attempts in _create_mqtt_connection. It is kept
            # separate from connection_lock so paho's callbacks can still
            # update the state while an attempt is waiting for the broker.
            self._connect_attempt_lock = threading.Lock()
            self._connection_in_progress = False

            # Topics to restore after every (re)connect
            self.subscribed_topics = set()

            # Connection and reconnection parameters
            self.max_reconnect_attempts = 3
            self.base_reconnect_delay = 2  # Base delay in seconds
            # Upper bound for the exponential backoff delay (read by connect()).
            self.reconnect_delay_max = 60  # seconds
            # Not read in this class; presumably meant for the disabled
            # background connection monitor.
            self.connection_retry_loop_delay = 30  # seconds

            # Device topic prefix (restricted to device-specific topics)
            self.device_topic_prefix = f"vyom-mqtt-msg/{self.machine_id}"

            # Initialize MQTT client
            self.mqtt_client = None
            self._setup_mqtt_client()

            # Create initial connection
            self._create_mqtt_connection()

            # Start background connection monitor
            # self._start_backgd_conn_monitor()

            print("MqttClient initialized successfully!")

        except Exception as e:
            print(f"Error: Error initializing MqttClient: {str(e)}")
            raise

    def _verify_cert_files(self):
        """Raise FileNotFoundError unless every certificate file exists."""
        for file_path in [self.cert_path, self.pri_key_path, self.root_ca_path]:
            if not os.path.exists(file_path):
                print(f"Error: Certificate file not found: {file_path}")
                raise FileNotFoundError(
                    f"Required certificate file not found: {file_path}"
                )

    def _setup_mqtt_client(self):
        """Create the paho client, configure TLS and callbacks (no connect)."""
        try:
            self.mqtt_client = mqtt.Client(
                client_id=self.client_id,
                clean_session=False,  # Persistent session
                protocol=mqtt.MQTTv311,
                transport="tcp",
            )

            # Mutual TLS: the broker is verified against the root CA and the
            # client authenticates with its own certificate and key.
            self.mqtt_client.tls_set(
                ca_certs=self.root_ca_path,
                certfile=self.cert_path,
                keyfile=self.pri_key_path,
                cert_reqs=ssl.CERT_REQUIRED,
                tls_version=ssl.PROTOCOL_TLS,
                ciphers=None,
            )

            # Disable hostname verification (similar to mosquitto client behavior).
            # NOTE(review): the chain is still verified against the root CA,
            # but the certificate's hostname/IP is not checked.
            self.mqtt_client.tls_insecure_set(True)

            self.mqtt_client.on_connect = self._on_connect
            self.mqtt_client.on_disconnect = self._on_disconnect
            self.mqtt_client.on_message = self._on_message
            self.mqtt_client.on_subscribe = self._on_subscribe
            self.mqtt_client.on_publish = self._on_publish
            self.mqtt_client.on_log = self._on_log

            self.mqtt_client.max_inflight_messages_set(20)
            self.mqtt_client.max_queued_messages_set(100)

            print("MQTT client configured successfully with SSL/TLS")

        except Exception as e:
            print(f"Error: Failed to setup MQTT client: {str(e)}")
            raise

    def _backoff_delay(self, attempt: int) -> float:
        """Return the exponential backoff delay for 0-based ``attempt``, capped."""
        return min(self.base_reconnect_delay * (2**attempt), self.reconnect_delay_max)

    def _wait_for_connection(self, timeout: float) -> bool:
        """Poll until ``_on_connect`` sets ``is_connected`` or ``timeout`` elapses.

        Must be called WITHOUT holding ``connection_lock``: ``_on_connect``
        needs that lock to flip ``is_connected``.
        """
        deadline = time.monotonic() + timeout
        while not self.is_connected and time.monotonic() < deadline:
            time.sleep(0.1)
        return self.is_connected

    def _stop_loop_quietly(self):
        """Best-effort stop of paho's network thread after a failed attempt."""
        if not self.mqtt_client:
            return
        try:
            self.mqtt_client.loop_stop()
        except Exception as e:
            print(f"Warning: Failed to stop MQTT network loop: {str(e)}")

    def _create_mqtt_connection(self):
        """Open the broker connection, retrying with exponential backoff.

        Returns immediately if already connected. Concurrent callers are
        serialized on ``_connect_attempt_lock``.

        Raises:
            ConnectionError: If all ``max_reconnect_attempts`` attempts fail.
        """
        with self._connect_attempt_lock:
            with self.connection_lock:
                if self.is_connected:
                    self._connection_in_progress = False
                    print("Connection already established, skipping reconnection")
                    return
                self._connection_in_progress = True

            # connection_lock is released from here on. Holding it while
            # waiting would deadlock: _on_connect (paho network thread) could
            # never set is_connected, and loop_stop() would then block
            # joining that stuck thread.
            try:
                for attempt in range(self.max_reconnect_attempts):
                    try:
                        if attempt > 0:
                            print(
                                f"Reconnection attempt {attempt + 1}/{self.max_reconnect_attempts}"
                            )
                            time.sleep(self._backoff_delay(attempt))

                        if not self._test_network_connectivity():
                            raise ConnectionError("Network connectivity test failed")

                        print(
                            f"Attempting to connect to {self.broker_host}:{self.broker_port}"
                        )
                        self.mqtt_client.connect(
                            self.broker_host, self.broker_port, 1200
                        )  # 20 minutes keep alive
                        self.mqtt_client.loop_start()

                        if self._wait_for_connection(15):
                            # _on_connect has already re-subscribed saved topics.
                            print("Successfully connected to MQTT broker")
                            return
                        raise TimeoutError("Connection timeout")

                    except Exception as e:
                        print(
                            f"Warning: MQTT connection attempt {attempt + 1} failed: {str(e)}"
                        )
                        self._stop_loop_quietly()

                raise ConnectionError(
                    f"Could not connect to MQTT broker after {self.max_reconnect_attempts} attempts"
                )

            finally:
                with self.connection_lock:
                    self._connection_in_progress = False

    def _test_network_connectivity(self):
        """Return True if a plain TCP connection to the broker port succeeds.

        The socket is closed immediately. No TLS handshake is attempted, so
        this only checks reachability. Never raises.
        """
        try:
            print(f"Testing if port {self.broker_port} is open on {self.broker_host}")
            # The context manager closes the socket even if connect_ex raises
            # (e.g. a DNS resolution error).
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(5)
                result = sock.connect_ex((self.broker_host, self.broker_port))

            if result == 0:
                print("Port is reachable")
                return True
            print(f"Error: Port {self.broker_port} is not reachable")
            return False

        except Exception as e:
            print(f"Error: Network connectivity test failed: {str(e)}")
            return False

    def _on_log(self, client, userdata, level, buf):
        """paho log callback: echo every client log line."""
        print(f"MQTT Log: {buf}")

    def _on_connect(self, client, userdata, flags, rc):
        """paho CONNACK callback: record the result and restore subscriptions on success."""
        if rc == 0:
            print(
                f"Device client connected successfully to {self.broker_host}:{self.broker_port}"
            )
            with self.connection_lock:
                self.is_connected = True
                self._connection_in_progress = False

            self._resubscribe_to_topics()
        else:
            print(f"Error: Failed to connect, return code: {rc}")
            print(f"Error: Connection error meaning: {mqtt.connack_string(rc)}")
            with self.connection_lock:
                self.is_connected = False
                self._connection_in_progress = False

    def _on_disconnect(self, client, userdata, rc):
        """paho disconnect callback: mark the client as disconnected."""
        print(f"Warning Device client disconnected, return code: {rc}")
        with self.connection_lock:
            self.is_connected = False

    def _on_message(self, client, userdata, msg):
        """paho message callback: decode the payload as UTF-8 and forward it.

        Errors (including non-UTF-8 payloads and callback exceptions) are
        printed, not propagated, so they cannot kill paho's network thread.
        """
        try:
            topic = msg.topic
            payload = msg.payload.decode("utf-8")
            print(f"Received message from topic '{topic}': {payload}")

            if self.callback:
                self.callback(topic, payload)
            else:
                print("No callback provided, skipping callback execution")

        except Exception as e:
            print(f"Error: Error processing received message: {str(e)}")

    def _on_subscribe(self, client, userdata, mid, granted_qos):
        """paho SUBACK callback."""
        print(f"Subscription confirmed with mid: {mid}, QoS: {granted_qos}")

    def _on_publish(self, client, userdata, mid):
        """paho publish-complete callback."""
        print(f"Message published with mid: {mid}")

    def connect(self):
        """Connect to the broker with exponential backoff.

        Returns:
            True if connected (or already connected). False if another attempt
            is in progress, the port is unreachable, or every attempt failed.
        """
        with self.connection_lock:
            if self.is_connected:
                print("Already connected to MQTT broker")
                return True

            if self._connection_in_progress:
                print("Connection attempt already in progress")
                return False

            self._connection_in_progress = True

        try:
            if not self._test_network_connectivity():
                return False

            for attempt in range(self.max_reconnect_attempts):
                try:
                    if attempt > 0:
                        print(
                            f"Reconnection attempt {attempt + 1}/{self.max_reconnect_attempts}"
                        )
                        time.sleep(self._backoff_delay(attempt))

                    print(
                        f"Attempting to connect to {self.broker_host}:{self.broker_port}"
                    )
                    # NOTE(review): 60 s keepalive here vs 1200 s in
                    # _create_mqtt_connection -- confirm which is intended.
                    self.mqtt_client.connect(self.broker_host, self.broker_port, 60)
                    self.mqtt_client.loop_start()

                    if self._wait_for_connection(15):
                        print("Successfully connected to MQTT broker")
                        return True

                    print(f"Connection attempt {attempt + 1} timed out")
                    self.mqtt_client.loop_stop()

                except Exception as e:
                    print(f"Error: Connection attempt {attempt + 1} failed: {str(e)}")
                    self._stop_loop_quietly()

            print("Error: Failed to connect after all attempts")
            return False

        finally:
            with self.connection_lock:
                self._connection_in_progress = False

    def _resubscribe_to_topics(self):
        """Re-issue a QoS 1 SUBSCRIBE for every remembered topic."""
        # Iterate a snapshot: subscribe_to_topic may add topics from another thread.
        for topic in self.subscribed_topics.copy():
            try:
                result = self.mqtt_client.subscribe(topic, qos=1)
                if result[0] == mqtt.MQTT_ERR_SUCCESS:
                    print(f"Resubscribed to topic: {topic}")
                else:
                    print(f"Error: Failed to resubscribe to topic: {topic}")
            except Exception as e:
                print(f"Error: Error resubscribing to {topic}: {str(e)}")

    def subscribe_to_topic(self, topic: str):
        """Subscribe to ``topic`` at QoS 1 and remember it for reconnects.

        Raises:
            ConnectionError: If not connected and reconnecting fails.
            Exception: If paho rejects the SUBSCRIBE request.
        """
        try:
            # Topic restriction is disabled for testing:
            # if not self._is_topic_allowed(topic):
            #     print(f"Error: Topic '{topic}' is not allowed for device client")
            #     raise PermissionError(f"Device client not allowed to subscribe to topic: {topic}")

            if not self.is_connected:
                if not self.connect():
                    raise ConnectionError("Failed to connect to MQTT broker")

            result = self.mqtt_client.subscribe(topic, qos=1)
            if result[0] == mqtt.MQTT_ERR_SUCCESS:
                self.subscribed_topics.add(topic)
                print(f"Subscribed to topic: {topic}")
            else:
                print(f"Error: Failed to subscribe to topic: {topic}")
                raise Exception(f"Subscription failed with code: {result[0]}")

        except Exception as e:
            print(f"Error: Subscription to {topic} failed: {str(e)}")
            raise

    def publish_message(self, topic: str, payload, retain: bool = False):
        """Publish ``payload`` to ``topic`` at QoS 1; dicts are JSON-encoded.

        Returns:
            True if paho accepted the message for sending, False otherwise.
            Never raises.
        """
        try:
            # Topic restriction is disabled for testing:
            # if not self._is_topic_allowed(topic):
            #     print(f"Error: Topic '{topic}' is not allowed for device client")
            #     return False

            if not self.is_connected:
                if not self.connect():
                    print("Error: Failed to connect to MQTT broker for publishing")
                    return False

            if isinstance(payload, dict):
                payload = json.dumps(payload)

            print(f"Publishing message to topic: {topic}")
            result = self.mqtt_client.publish(topic, payload, qos=1, retain=retain)

            if result.rc == mqtt.MQTT_ERR_SUCCESS:
                print(f"Message published successfully to topic: {topic}")
                return True
            print(f"Error: Failed to publish message, return code: {result.rc}")
            return False

        except Exception as e:
            print(f"Error: Publish to {topic} failed: {str(e)}")
            return False

    def _is_topic_allowed(self, topic: str) -> bool:
        """Return True if ``topic`` is one of this device's topics or below one.

        A pattern matches the topic itself or any sub-topic
        (``pattern + "/..."``). Matching on a level boundary keeps, e.g.,
        ``vyom-mqtt-msg/320/...`` from being accepted for machine ``32``.
        Callers in this file currently have this check commented out.
        """
        allowed_patterns = [
            self.device_topic_prefix,  # vyom-mqtt-msg/<machine_id>[/...]
            f"commands/{self.machine_id}",  # commands/<machine_id>[/...]
            f"config/{self.machine_id}",  # config/<machine_id>[/...]
            # Testing relaxation: permits the bare "#" (all-topics) wildcard.
            # It does NOT make arbitrary concrete topics allowed.
            "#",
        ]

        return any(
            topic == pattern or topic.startswith(pattern + "/")
            for pattern in allowed_patterns
        )

    def disconnect(self):
        """Send DISCONNECT, stop the network thread and clear ``is_connected``.

        Never raises; errors are printed.
        """
        try:
            if self.mqtt_client:
                # paho's documented shutdown order: disconnect, then loop_stop.
                self.mqtt_client.disconnect()
                self.mqtt_client.loop_stop()

            with self.connection_lock:
                self.is_connected = False

            print("MQTT device client disconnected successfully")

        except Exception as e:
            print(f"Error: Failed to disconnect MQTT client: {str(e)}")

    def cleanup(self):
        """Release client resources (currently: disconnect). Never raises."""
        try:
            self.disconnect()
            print("cleanup successful MQTT client")
        except Exception as e:
            print(f"Error: Error in cleanup MQTT client: {str(e)}")


def message_callback(topic, payload):
    """Example MqttClient callback: echo a message and show it as JSON if it parses.

    Any unexpected error is printed rather than raised.
    """
    try:
        print(f"Received message in callback '{topic}': {payload}")
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError:
            print(f"Payload is not JSON, treating as string: {payload}")
        else:
            print(f"Parsed JSON payload: {decoded}")
    except Exception as err:
        print(f"Error: Error in calling callback: {str(err)}")


def main():
    """Example usage of MqttClient.

    Builds a client with the default broker settings (the constructor
    connects), then idles until SIGINT or SIGTERM so incoming messages are
    printed by ``message_callback``. The client is always cleaned up before
    returning. Exits with status 1 if the broker cannot be reached.
    """
    import signal
    import sys

    mqtt_client = MqttClient(
        callback=message_callback,
    )

    # The constructor already connected; connect() returns True right away in
    # that case and only retries if the connection has dropped since.
    if not mqtt_client.connect():
        print("Error: Failed to connect to MQTT broker")
        mqtt_client.cleanup()
        # Non-zero status so scripts/supervisors can detect the failure.
        sys.exit(1)

    # Example subscribe/publish, disabled. Topics use the client's own
    # machine_id so they match the identity of the certificate in use.
    # subscribe_topic_2 = f"vyom-mqtt-msg/{mqtt_client.machine_id}/#"
    #
    # try:
    #     mqtt_client.subscribe_to_topic(subscribe_topic_2)
    #
    #     # Test publishing
    #     publish_topic = f"vyom-mqtt-msg/{mqtt_client.machine_id}/hello.json"
    #     mqtt_client.publish_message(publish_topic, "Test message using client cert")
    #     mqtt_client.publish_message(
    #         publish_topic, {"message": "JSON payload", "timestamp": time.time()}
    #     )
    # except Exception as e:
    #     print(f"Error: Error in main execution: {str(e)}")

    is_running = True

    def signal_handler(sig, frame):
        # nonlocal is required: without it this assignment would create a new
        # local variable and the loop below would never stop.
        nonlocal is_running
        print(f"Received signal {sig}, shutting down...")
        is_running = False

    try:
        # These handlers replace Python's default SIGINT -> KeyboardInterrupt,
        # so shutdown is driven by the is_running flag, not an exception.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        print("Listening for messages... Press Ctrl+C to exit")
        # Short sleeps keep shutdown latency low: time.sleep() resumes after a
        # signal handler returns (PEP 475), so the flag is only re-checked
        # once the current sleep finishes.
        while is_running:
            time.sleep(1)

    except KeyboardInterrupt:
        # Ctrl+C arriving before the handlers were installed.
        print("\nShutting down...")
    finally:
        mqtt_client.cleanup()


if __name__ == "__main__":
    main()
